#!/usr/bin/env python3

from urllib import request
import json
import operator
import time
import logging
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

# Names offered as choices for both --aggregator and --downsample.
AGGREGATORS = (
    "avg", "count", "dev", "diff",
    "ep50r3", "ep50r7",
    "ep75r3", "ep75r7",
    "ep90r3", "ep90r7",
    "ep95r3", "ep95r7",
    "ep99r3", "ep99r7",
    "ep999r3", "ep999r7",
    "mimmin", "mimmax", "min", "max", "none",
    "p50", "p75", "p90", "p95", "p99", "p999",
    "sum", "zimsum",
)
# Choices for --downsample-fill-policy.
FILL_POLICIES = ("none", "nan", "null", "zero")
# Comparison names; each one is looked up as a function in the operator module.
METHODS = ("gt", "ge", "lt", "le", "eq", "ne")
# Alarm severities an expression may raise.
ALARMS = ("warn", "crit")

# Logging: bare messages (no level or timestamp prefix) through a stream handler.
logformat = "%(message)s"
formatter = logging.Formatter(logformat)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
log = logging.getLogger("repair-tsd")
log.setLevel(logging.INFO)
log.addHandler(ch)


def _get_metrics(query, timeout):
    """
    Actually get data from OpenTSDB

    :param str query: the query string
    :param int timeout: how long to wait for OpenTSDB to respond
    :returns: yields metrics from the resulting list one at a time
    :rtype: dict (generator)
    """
    try:
        res = request.urlopen(query, timeout=timeout)
        metrics = json.loads(res.read().decode("utf-8"))
    except Exception as e:
        log.error("Failed to collect metrics: {}".format(e))
        exit(1)
    for m in metrics:
        yield m


def build_query(args):
    """
    Format the query URL we'll be sending to OpenTSDB.

    The result has the shape
    ``<scheme>://<host>:<port>/api/query?start=<N>s-ago&noAnnotations=true``
    ``&m=<aggregator>:[rate[{counter,,reset}]:][<W>s-<fn>[-<fill>]:]<metric>[{tags}]``.

    :param args: parsed options; reads ssl, host, port, duration, aggregator,
                 rate, rate_counter, rate_reset_value, downsample,
                 downsample_window, downsample_fill_policy, metric and tag
    :type args: argparse.Namespace
    :returns: formatted query
    :rtype: str
    """
    scheme = "https" if args.ssl else "http"
    query = "{}://{}:{}/api/query?".format(scheme, args.host, args.port)
    query += "start={}s-ago&noAnnotations=true&m={}:".format(args.duration,
                                                             args.aggregator)
    if args.rate:
        query += "rate"
        if args.rate_counter or args.rate_reset_value:
            counter = "counter" if args.rate_counter else ""
            reset = args.rate_reset_value if args.rate_reset_value else ""
            # Literal braces must be doubled inside str.format(); the
            # previous "{}}" template raised ValueError for any non-zero
            # reset value.  The middle field between the commas stays empty.
            query += "{{{},,{}}}".format(counter, reset)
        query += ":"
    if args.downsample:
        query += "{}s-{}".format(args.downsample_window, args.downsample)
        if args.downsample_fill_policy:
            query += "-{}".format(args.downsample_fill_policy)
        query += ":"
    query += args.metric
    if args.tag:
        # Tags are passed through verbatim as a comma-separated filter list.
        query += "{" + ",".join(args.tag) + "}"
    return query


def build_comparisons(expressions):
    """
    Turn strings like 'gt,100,crit' into tuples that python can use to
    evaluate state.

    Critical checks are placed ahead of warning checks so that a datapoint
    which should be CRITICAL is never counted as a WARNING (evaluation stops
    at the first comparison that matches).  Within each severity the
    expressions keep the order they were given in.

    :param list expressions: all expression strings ("method,value,alarm")
    :returns: (comparator function, float threshold, alarm name) tuples
    :rtype: list
    :raises SystemExit: with status 1 on a malformed expression, an unknown
                        method or alarm, or a non-numeric value
    """
    comparisons = []
    for expression in expressions:
        fields = expression.split(",")
        # Guard the unpacking: a wrong field count would otherwise surface
        # as an unhandled ValueError traceback.
        if len(fields) != 3:
            log.error("Invalid expression (expected method,value,alarm): {}".format(expression))
            raise SystemExit(1)
        method, raw_value, alarm = fields
        if method not in METHODS:
            log.error("Invalid comparison method.")
            raise SystemExit(1)
        if alarm not in ALARMS:
            log.error("Invalid alarm type.")
            raise SystemExit(1)
        try:
            value = float(raw_value)
        except ValueError:
            log.error("Alarm value must be a number.")
            raise SystemExit(1)
        # METHODS only holds names of functions in the operator module.
        comparisons.append((getattr(operator, method), value, alarm))
    # Ensure we check criticals first, since all comparisons are ORed.
    # False sorts before True and sorted() is stable.
    return sorted(comparisons, key=lambda comp: comp[2] != "crit")


def _process_metric(m, args, comparisons, now):
    """
    Evaluate a single metric from OpenTSDB.
    In this case, a metric is a object containing a list
    of tuples of (ts, value) and a separate group of tags related
    to the metric.

    :param dict m: the actual metric data
    :param dict args: all arguments needed to perform evaluations
    :param list comparisons: all comparison tuples
    :param float now: the current time (generated by time.time())
    :returns: object describing the metric evaluated and its state
    :rtype: dict
    """
    value_count = len(m["dps"])
    mresult = {"crit": 0, "crit_alarm": False, "warn_alarm": False, "warn": 0,
               "crit_percent": 0, "warn_percent": 0, "empty": False, "metric_avg": 0}
    if args.tag:
        keys = [t.split("=")[0] for t in args.tag]
        mresult["tags"] = [v for k, v in m["tags"].items() if k in keys]
    if value_count < 1:
        mresult["empty"] = True
        return mresult

    avglist = []
    for ts, d in m["dps"].items():
        # handle out-of-time metrics
        if args.ignore_recent:
            try:
                ts = float(ts)
            except ValueError:
                log.error("Bad timestamp for {}: {}".format(",".join(mresult["tags"]), ts))
                mresult["crit_alarm"] = True
                break
            delta = now - ts
            if delta >= args.ignore_recent:
                log.debug("Timestamp outside evaluation range: {}".format(ts))
                continue
        avglist.append(d)
        for comparison in comparisons:
            comparator, value, alarm = comparison
            if comparator(d, value):
                mresult[alarm] += 1
                break
    mresult["metric_avg"] = sum(avglist)/len(avglist)
    mresult["crit_percent"] = mresult["crit"] / value_count * 100
    mresult["warn_percent"] = mresult["warn"] / value_count * 100
    if mresult["crit"] > 0:
        if args.percent_over > 0 and mresult["crit_percent"] > args.percent_over:
            mresult["crit_alarm"] = True
        else:
            mresult["crit_alarm"] = True
    if mresult["warn"] > 0:
        if args.percent_over > 0 and mresult["warn_percent"] > args.percent_over:
            mresult["warn_alarm"] = True
        else:
            mresult["warn_alarm"] = True
    return mresult


def process_metrics(query, args, comparisons):
    """
    Evaluate every metric "group" returned for a query.

    A query such as system.load5{host=*} can return several groups; each
    one coming out of _get_metrics() is handed to _process_metric() with
    the same reference time, taken once when iteration starts.

    :param str query: The query to send to OpenTSDB
    :param args: All potential evaluation arguments (args.timeout is used here)
    :param list comparisons: all comparison tuples to use for evaluating state
    :returns: yields each evaluated metric object as it completes
    :rtype: dict (generator)
    """
    started = time.time()
    fetched = _get_metrics(query, args.timeout)
    yield from (_process_metric(metric, args, comparisons, started)
                for metric in fetched)


def cli_opts(argv=None):
    """
    Define and parse the command line options.

    :param list argv: argument strings to parse; None (the default) makes
                      argparse read them from sys.argv[1:]
    :returns: the parsed options
    :rtype: argparse.Namespace
    """
    parser = ArgumentParser(description="check tsd query",
                            formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("-H", "--host", default="localhost", type=str,
                        help="host to check for stats")
    parser.add_argument("-p", "--port", default=4242, type=int,
                        help="port to check for stats")
    parser.add_argument("-m", "--metric", required=True, type=str,
                        help="Metric to query.")
    parser.add_argument("-t", "--tag", action="append", default=[],
                        help="Tags to filter the metric on.")
    parser.add_argument("-d", "--duration", type=int, default=3600,
                        help="How far back to look for data.")
    parser.add_argument("-D", "--downsample", default=None,
                        help="Downsample function", choices=AGGREGATORS)
    parser.add_argument("-W", "--downsample-window", type=int, default=60,
                        help="Window size over which to downsample.")
    parser.add_argument("-F", "--downsample-fill-policy", default=None,
                        help="Downsample Fill Policies", choices=FILL_POLICIES)
    parser.add_argument("-a", "--aggregator", default="sum",
                        help="Aggregation method", choices=AGGREGATORS)
    parser.add_argument("-r", "--rate", default=False,
                        action="store_true", help="Use rate value as comparison operand.")
    parser.add_argument("--rate-counter", default=False,
                        action="store_true", help="Use rate counter")
    parser.add_argument("--rate-reset-value", default=0,
                        type=int, help="rate reset value")
    parser.add_argument("-e", "--expression", action="append", required=True,
                        help="Comparison expression. e.g. gt,100,warn (multiple allowed)\n"
                        "Allowed methods: {}\nAllowed alarms: {}".format(",".join(METHODS), ",".join(ALARMS)))
    # Data points whose age is at or above this many seconds are skipped
    # during evaluation; 0 disables the filter.
    parser.add_argument("-I", "--ignore-recent", default=0, type=int,
                        help="Ignore data points that are >= this many seconds old.")
    parser.add_argument("-P", "--percent-over", dest="percent_over", default=0,
                        type=float, help="Only alarm if PERCENT of the data"
                        " points violate the threshold.")
    parser.add_argument("-S", "--ssl", default=False, action="store_true",
                        help="Make queries to OpenTSDB via SSL (https)")
    parser.add_argument("-T", "--timeout", type=int, default=30,
                        help="How long to wait for the response from TSD.")
    parser.add_argument("-A", "--alarm-empty", default=False,
                        action="store_true", help="Alert when an empty series returns")
    parser.add_argument("--debug", default=False,
                        action="store_true", help="Verbose logging")
    return parser.parse_args(argv)


def main():
    """
    Run the check: parse options, query OpenTSDB, report alarms.

    Exits with status 2 when any series is critical, 1 when any series is
    in warning (or when an option is invalid), and returns normally
    otherwise.

    :raises SystemExit: as described above
    """
    args = cli_opts()
    if args.debug:
        log.setLevel(logging.DEBUG)
    if args.percent_over > 100 or args.percent_over < 0:
        log.error("Percentage over must be a value from 0-100: {}".format(args.percent_over))
        raise SystemExit(1)
    # A zero-length window is not a usable downsample interval either, and
    # the message must report the window itself.
    if args.downsample_window <= 0:
        log.error("Downsample window must be positive: {}".format(args.downsample_window))
        raise SystemExit(1)
    comparisons = build_comparisons(args.expression)
    query = build_query(args)

    crits = []
    warns = []
    total = 0
    for r in process_metrics(query, args, comparisons):
        # Results may lack "tags" when no tag filter was requested.
        tags = r.get("tags", [])
        total += 1
        # Empty series never carry an alarm flag, so they have to be
        # handled before the "nothing alarmed" shortcut below.
        if r["empty"]:
            if args.alarm_empty:
                log.info("{} => no data returned in range.".format(",".join(tags)))
                crits.append(tags)
            continue
        if not r["crit_alarm"] and not r["warn_alarm"]:
            continue
        if r["crit_alarm"]:
            crits.append(tags)
        else:
            warns.append(tags)
        alerts = r["crit"] + r["warn"]
        perc = r["crit_percent"] + r["warn_percent"]
        log.info("{} => alarmed {} times in range. ({}% alarms). Avg. Value: {}".format(",".join(tags),
                                                                                        alerts, perc, r["metric_avg"]))
    log.info("{} total metrics processed".format(total))
    crit_count = len(crits)
    warn_count = len(warns)
    if crit_count > 0:
        log.info("{} Critical Alarms".format(crit_count))
    if warn_count > 0:
        log.info("{} Warning Alarms".format(warn_count))
    # Critical takes precedence over warning for the exit status.
    if crit_count > 0:
        raise SystemExit(2)
    if warn_count > 0:
        raise SystemExit(1)


# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
